# 6.3 模型微调 - timm 除了使用`torchvision.models`进行预训练以外,还有一个常见的预训练模型库,叫做`timm`,这个库是由Ross Wightman创建的。里面提供了许多计算机视觉的SOTA模型,可以当作是torchvision的扩充版本,并且里面的模型在准确度上也较高。在本章内容中,我们主要是针对这个库的预训练模型的使用做叙述,其他部分内容(数据扩增,优化器等)如果大家感兴趣,可以参考以下两个链接。 - Github链接:https://github.com/rwightman/pytorch-image-models - 官网链接:https://fastai.github.io/timmdocs/ https://rwightman.github.io/pytorch-image-models/ ## 6.3.1 timm的安装 关于timm的安装,我们可以选择以下两种方式进行: 1. 通过pip安装 ```shell pip install timm ``` 2. 通过源码编译安装 ```shell git clone https://github.com/rwightman/pytorch-image-models cd pytorch-image-models && pip install -e . ``` ## 6.3.2 如何查看预训练模型种类 1. 查看timm提供的预训练模型 截止到2022.3.27日为止,timm提供的预训练模型已经达到了592个,我们可以通过`timm.list_models()`方法查看timm提供的预训练模型(注:本章测试代码均是在jupyter notebook上进行) ```python import timm avail_pretrained_models = timm.list_models(pretrained=True) len(avail_pretrained_models) ``` ```shell 592 ``` 2. 查看特定模型的所有种类 每一种系列可能对应着不同方案的模型,比如Resnet系列就包括了ResNet18,50,101等模型,我们可以在`timm.list_models()`传入想查询的模型名称(模糊查询),比如我们想查询densenet系列的所有模型。 ```python all_densnet_models = timm.list_models("*densenet*") all_densnet_models ``` 我们发现以列表的形式返回了所有densenet系列的所有模型。 ```shell ['densenet121', 'densenet121d', 'densenet161', 'densenet169', 'densenet201', 'densenet264', 'densenet264d_iabn', 'densenetblur121d', 'tv_densenet121'] ``` 3. 查看模型的具体参数 当我们想查看下模型的具体参数的时候,我们可以通过访问模型的`default_cfg`属性来进行查看,具体操作如下 ```python model = timm.create_model('resnet34',num_classes=10,pretrained=True) model.default_cfg ``` ```python {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth', 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'conv1', 'classifier': 'fc', 'architecture': 'resnet34'} ``` 除此之外,我们可以通过访问这个[链接](https://rwightman.github.io/pytorch-image-models/results/) 查看提供的预训练模型的准确度等信息。 ## 6.3.3 使用和修改预训练模型 在得到我们想要使用的预训练模型后,我们可以通过`timm.create_model()`的方法来进行模型的创建,我们可以通过传入参数`pretrained=True`,来使用预训练模型。同样的,我们也可以使用跟torchvision里面的模型一样的方法查看模型的参数,类型/ ```python import timm import torch model = timm.create_model('resnet34',pretrained=True) x = torch.randn(1,3,224,224) output = model(x) output.shape ``` ```shell torch.Size([1, 1000]) ``` - 查看某一层模型参数(以第一层卷积为例) ```python model = timm.create_model('resnet34',pretrained=True) list(dict(model.named_children())['conv1'].parameters()) ``` ```python [Parameter containing: tensor([[[[-2.9398e-02, -3.6421e-02, -2.8832e-02, ..., -1.8349e-02, -6.9210e-03, 1.2127e-02], [-3.6199e-02, -6.0810e-02, -5.3891e-02, ..., -4.2744e-02, -7.3169e-03, -1.1834e-02], ... [ 8.4563e-03, -1.7099e-02, -1.2176e-03, ..., 7.0081e-02, 2.9756e-02, -4.1400e-03]]]], requires_grad=True)] ``` - 修改模型(将1000类改为10类输出) ```python model = timm.create_model('resnet34',num_classes=10,pretrained=True) x = torch.randn(1,3,224,224) output = model(x) output.shape ``` ```python torch.Size([1, 10]) ``` - 改变输入通道数(比如我们传入的图片是单通道的,但是模型需要的是三通道图片) 我们可以通过添加`in_chans=1`来改变 ```python model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1) x = torch.randn(1,1,224,224) output = model(x) ``` ## 6.3.4 模型的保存 timm库所创建的模型是`torch.model`的子类,我们可以直接使用torch库中内置的模型参数保存和加载的方法,具体操作如下方代码所示 ```python torch.save(model.state_dict(),'./checkpoint/timm_model.pth') model.load_state_dict(torch.load('./checkpoint/timm_model.pth')) ``` ## 参考材料 1. https://www.aiuai.cn/aifarm1967.html 2. https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055 3. https://chowdera.com/2022/03/202203170834122729.html